{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer learning gender classifier by using dlib face feature extractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import timeit\n",
    "import cv2\n",
    "from skimage import io as io\n",
    "import face_recognition as fr\n",
    "import numpy as np\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "from sklearn import datasets, svm, metrics\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import random\n",
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('celeba_smiling_data.pkl', 'rb') as f:\n",
    "    gender_dataset_raw = pickle.load(f)\n",
    "random.shuffle(gender_dataset_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4865"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(gender_dataset_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length of embedding train list: 4000\n",
      "lenght of label train list: 4000\n",
      "length of embedding test list: 865\n",
      "lenght of label test list: 865\n"
     ]
    }
   ],
   "source": [
    "embedding_list_train = list()\n",
    "gender_label_list_train = list()\n",
    "\n",
    "embedding_list_test = list()\n",
    "gender_label_list_test = list()\n",
    "\n",
    "for emb, label in gender_dataset_raw[0:4000]:\n",
    "    embedding_list_train.append(emb)\n",
    "    gender_label_list_train.append(label)\n",
    "\n",
    "for emb, label in gender_dataset_raw[4000:]:\n",
    "    embedding_list_test.append(emb)\n",
    "    gender_label_list_test.append(label)\n",
    "    \n",
    "print('length of embedding train list: {}'.format(len(embedding_list_train)))\n",
    "print('lenght of label train list: {}'.format(len(gender_label_list_train)))\n",
    "print('length of embedding test list: {}'.format(len(embedding_list_test)))\n",
    "print('lenght of label test list: {}'.format(len(gender_label_list_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM]"
     ]
    }
   ],
   "source": [
    "param_grid = {'C': [1e1, 1e2],\n",
    "              'gamma': [10, 100], \n",
    "              'decision_function_shape':[\"ovo\"],\n",
    "                \"degree\":[3,4]}\n",
    "classifier = GridSearchCV(svm.SVC(kernel='rbf', class_weight=\"balanced\", verbose=True),\n",
    "                   param_grid, cv=5)\n",
    "# classifier = svm.SVC(gamma='auto', kernel='rbf', C=20, degree=3, class_weight=\"balanced\", \n",
    "#                     verbose=True)\n",
    "classifier.fit(embedding_list_train, gender_label_list_train)\n",
    "\n",
    "expected = gender_label_list_test\n",
    "predicted = classifier.predict(embedding_list_test)\n",
    "\n",
    "print(\"Classification report for classifier %s:\\n%s\\n\"\n",
    "      % (classifier, metrics.classification_report(expected, predicted)))\n",
    "print(\"Confusion matrix:\\n%s\" % metrics.confusion_matrix(expected, predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM][LibSVM]Classification report for classifier GridSearchCV(cv=5, error_score='raise-deprecating',\n",
    "       estimator=SVC(C=1.0, cache_size=200, class_weight='balanced', coef0=0.0,\n",
    "  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',\n",
    "  kernel='rbf', max_iter=-1, probability=False, random_state=None,\n",
    "  shrinking=True, tol=0.001, verbose=True),\n",
    "       fit_params=None, iid='warn', n_jobs=None,\n",
    "       param_grid={'degree': [3, 4], 'C': [10.0, 100.0], 'decision_function_shape': ['ovo'], 'gamma': [10, 100]},\n",
    "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
    "       scoring=None, verbose=0):\n",
    "              precision    recall  f1-score   support\n",
    "\n",
    "          -1       0.68      0.73      0.70       446\n",
    "           1       0.69      0.63      0.66       419\n",
    "\n",
    "   micro avg       0.68      0.68      0.68       865\n",
    "   macro avg       0.68      0.68      0.68       865\n",
    "weighted avg       0.68      0.68      0.68       865\n",
    "\n",
    "\n",
    "Confusion matrix:\n",
    "[[326 120]\n",
    " [155 264]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier.best_estimator_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "test_img_path = \"/home/ubuntu/ihandy_seg/data/img_align_celeba\"\n",
    "test_imgs = [f for f in os.listdir(test_img_path) if not f.startswith('.')]\n",
    "table = pd.read_csv(\"/home/ubuntu/ihandy_seg/data/list_attr_celeba.csv\")\n",
    "img_label = table.set_index('image_id').to_dict()[\"Male\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_count = 0\n",
    "test_count = 0\n",
    "for i, img_file in enumerate(test_imgs[:100]):\n",
    "    img = io.imread(os.path.join(test_img_path, img_file))\n",
    "    try:\n",
    "        img_enc = fr.face_encodings(img)[0]\n",
    "    except:\n",
    "        print(\"None Face in this image:\", img_file)\n",
    "        continue\n",
    "    test_count += 1\n",
    "    img_predict = classifier.predict([img_enc])\n",
    "    if img_predict[0] == 1 and img_label[img_file] == 1:\n",
    "        correct_count += 1\n",
    "    if img_predict[0] == -1 and img_label[img_file] == -1:\n",
    "        correct_count += 1\n",
    "    plt.title(\"The prediced gender: {}\".format(img_predict))\n",
    "    plt.title(img_predict)\n",
    "    plt.imshow(img)\n",
    "    plt.show()\n",
    "print(\"accuracy:\", correct_count / test_count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img1 = io.imread('test1.jpg')\n",
    "img2 = io.imread('test2.jpg')\n",
    "\n",
    "img1_enc = fr.face_encodings(img1)[0]\n",
    "img2_enc = fr.face_encodings(img2)[0]\n",
    "\n",
    "img1_predict = classifier.predict([img1_enc])\n",
    "img2_predict = classifier.predict([img2_enc])\n",
    "\n",
    "plt.title('The predicted gender: {}'.format(img1_predict))\n",
    "plt.imshow(img1)\n",
    "plt.show()\n",
    "\n",
    "plt.title('The predicted gender: {}'.format(img2_predict))\n",
    "plt.imshow(img2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
